import cv2
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms


class MyDataset(Dataset):
    """Image dataset backed by a CSV file with ``filepath`` and ``label`` columns.

    Each item is the image at ``filepath`` loaded with OpenCV, converted to a
    tensor and resized, paired with the raw value of the ``label`` column.
    """

    def __init__(self, dataset_dir, csv_path, resize_shape) -> None:
        """Load the CSV index and build the image transform pipeline.

        Args:
            dataset_dir: Stored on the instance as ``self.dataset_dir``; it is
                not used to resolve paths anywhere in this class.
            csv_path: Path to a UTF-8 CSV file containing at least the
                ``filepath`` and ``label`` columns.
            resize_shape: Target size handed unchanged to
                ``transforms.Resize``.
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.csv_path = csv_path
        self.shape = resize_shape
        self.df = pd.read_csv(self.csv_path, encoding='utf-8')
        self.transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.shape)
        ])

    def __len__(self) -> int:
        """Return the number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the ``(image_tensor, label)`` pair stored at row ``index``.

        Args:
            index: Zero-based row position; negative values count from the end.

        Returns:
            A tuple of the transformed image and the value of the ``label``
            column for that row.

        Raises:
            IndexError: If ``index`` is outside the range of the CSV rows.
            OSError: If OpenCV cannot read or decode the image file.
        """
        # Positional lookup, done before any image I/O: an out-of-range index
        # must raise IndexError (not pandas' label-based KeyError) so that
        # plain ``for sample in dataset`` iteration terminates cleanly.
        path = self.df['filepath'].iloc[index]
        label = self.df['label'].iloc[index]

        # cv2.imread reports failure by returning None instead of raising;
        # surface that here with the offending path rather than letting the
        # transform fail later with an unrelated-looking TypeError.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")

        # NOTE(review): cv2.imread yields channels in BGR order. If the
        # consumer expects RGB (as a PIL-based loader would produce), convert
        # explicitly -- left unchanged here to keep existing outputs stable.
        return self.transformer(image), label
